from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from mooc.core.config import settings
from typing import Generator, Set
import importlib
import pkgutil
from pathlib import Path
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Create the database engine from the application settings.
# pool_pre_ping=True asks SQLAlchemy to test pooled connections for liveness
# on checkout; echo is driven by settings.SQLALCHEMY_ECHO.
engine = create_engine(
    settings.SQLALCHEMY_DATABASE_URI,
    pool_pre_ping=True,
    echo=settings.SQLALCHEMY_ECHO
)

# Create the session factory bound to the engine above, with autocommit and
# autoflush both disabled.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Create the declarative base class; Base.metadata is what the helpers below
# consult when looking for tables to create.
Base = declarative_base()

def get_existing_tables() -> Set[str]:
    """Return the set of table names reported by inspecting the engine."""
    table_names = inspect(engine).get_table_names()
    return set(table_names)

def import_models() -> None:
    """
    Import every module found in the ``models`` directory.

    The directory is located two levels up from this file, and each module
    discovered there is imported as ``mooc.models.<name>``. The stated intent
    is that importing them registers all model classes on ``Base.metadata``.
    Afterwards ``mooc.models.verify_all_models`` is called.
    """
    models_dir = Path(__file__).parents[1] / "models"
    discovered = [info.name for info in pkgutil.iter_modules([str(models_dir)])]
    for name in discovered:
        importlib.import_module(f"mooc.models.{name}")

    # Verify right after importing; imported here rather than at module top,
    # as in the original layout.
    from mooc.models import verify_all_models
    verify_all_models()

def create_missing_tables() -> None:
    """
    Create the tables registered on ``Base.metadata`` that the database lacks.

    The registered table names are compared with the names returned by
    ``get_existing_tables()``. When some are missing they are logged and
    created; otherwise a message saying all tables exist is logged.
    """
    existing_tables = get_existing_tables()
    registered = Base.metadata.tables
    missing_tables = set(registered) - existing_tables

    if not missing_tables:
        logger.info("All tables already exist")
        return

    logger.info("Creating missing tables: %s", missing_tables)
    # Hand the whole batch to create_all() instead of calling Table.create()
    # per table: create_all() emits CREATE statements in foreign-key
    # dependency order, whereas iterating a set could try to create a child
    # table before the parent it references. checkfirst=True additionally
    # skips any table that turns out to exist already.
    Base.metadata.create_all(
        engine,
        tables=[registered[name] for name in missing_tables],
        checkfirst=True,
    )

def init_db() -> None:
    """
    Initialize the database.

    Steps:
    1. Import all models (``import_models`` also runs the verification).
    2. Create whichever tables are missing.

    Any exception raised along the way is logged at error level and then
    re-raised to the caller.
    """
    try:
        # Make sure every model is imported and verified first.
        import_models()
        logger.info("All models imported successfully")

        # Then create the tables that do not exist yet.
        create_missing_tables()
        logger.info("Database initialization completed successfully")

        # Log the registered table names (debugging aid).
        from mooc.models import get_all_table_names
        registered = get_all_table_names()
        logger.debug(f"Registered tables: {registered}")

    except Exception as exc:
        failure = f"Database initialization failed: {str(exc)}"
        logger.error(failure)
        raise

def get_db() -> Generator:
    """
    Dependency that provides a database session.

    Yields a session created by ``SessionLocal`` and closes it once the
    consumer is finished with it, including when the consumer raises.
    """
    # Create the session before entering ``try``: if SessionLocal() itself
    # fails there is nothing to close, and the ``finally`` clause would
    # otherwise hit an unbound ``db`` and mask the real error with an
    # UnboundLocalError.
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()